{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " # MLP example using PySNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from torchvision import transforms\n",
    "from tqdm import tqdm\n",
    "\n",
    "from pysnn.connection import Linear\n",
    "from pysnn.neuron import FedeNeuron, Input\n",
    "from pysnn.learning import FedeSTDP\n",
    "from pysnn.encoding import PoissonEncoder\n",
    "from pysnn.network import SNNNetwork\n",
    "from pysnn.datasets import AND, BooleanNoise, Intensity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Parameter definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture: sizes of the input, hidden and output layers\n",
    "n_in = 6\n",
    "n_hidden = 10\n",
    "n_out = 1\n",
    "\n",
    "# Data\n",
    "duration = 200  # passed to PoissonEncoder together with dt\n",
    "intensity = 40  # used by the Intensity transform and to scale the labels\n",
    "num_workers = 0\n",
    "batch_size = 1\n",
    "\n",
    "# Neuronal Dynamics\n",
    "thresh = 0.8\n",
    "v_rest = 0\n",
    "alpha_v = 0.2\n",
    "tau_v = 5\n",
    "alpha_t = 1.0\n",
    "tau_t = 5\n",
    "duration_refrac = 5\n",
    "dt = 1\n",
    "delay = 3\n",
    "\n",
    "# The tuples below are star-unpacked as positional arguments into the\n",
    "# Input, FedeNeuron and Linear constructors, so the order of entries matters.\n",
    "i_dynamics = (dt, alpha_t, tau_t)\n",
    "# Fourth entry is alpha_t (previously alpha_v was passed twice and alpha_t was\n",
    "# never handed to the neurons). Assumes the FedeNeuron argument order is\n",
    "# (thresh, v_rest, alpha_v, alpha_t, dt, duration_refrac, tau_v, tau_t) -- TODO confirm.\n",
    "n_dynamics = (thresh, v_rest, alpha_v, alpha_t, dt, duration_refrac, tau_v, tau_t)\n",
    "c_dynamics = (batch_size, dt, delay)\n",
    "\n",
    "# Learning: arguments handed to FedeSTDP\n",
    "lr = 0.0001\n",
    "w_init = 0.5\n",
    "a = 0.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Network definition\n",
    " The API is mostly the same as for regular PyTorch. The main differences are that layers are composed of a `Neuron` and `Connection` type,\n",
    " and the layer has to be added to the network by calling the `add_layer` method. Lastly, all objects return both a\n",
    " spike (or activation potential) object and a trace object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Network(SNNNetwork):\n",
    "    \"\"\"Input layer followed by two Linear + FedeNeuron layers.\n",
    "\n",
    "    Layer sizes and dynamics come from the notebook-level parameters\n",
    "    (batch_size, n_in, n_hidden, n_out, i_dynamics, c_dynamics, n_dynamics).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        in_shape = (batch_size, 1, n_in)\n",
    "        hidden_shape = (batch_size, 1, n_hidden)\n",
    "        out_shape = (batch_size, 1, n_out)\n",
    "\n",
    "        # Input neurons\n",
    "        self.input = Input(in_shape, *i_dynamics, update_type=\"exponential\")\n",
    "\n",
    "        # First layer: connection and neuron, registered under the name \"fc1\"\n",
    "        self.mlp1_c = Linear(n_in, n_hidden, *c_dynamics)\n",
    "        self.neuron1 = FedeNeuron(hidden_shape, *n_dynamics)\n",
    "        self.add_layer(\"fc1\", self.mlp1_c, self.neuron1)\n",
    "\n",
    "        # Second (output) layer, registered under the name \"fc2\"\n",
    "        self.mlp2_c = Linear(n_hidden, n_out, *c_dynamics)\n",
    "        self.neuron2 = FedeNeuron(out_shape, *n_dynamics)\n",
    "        self.add_layer(\"fc2\", self.mlp2_c, self.neuron2)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Pass `input` through every component once and return the first output of `neuron2`.\"\"\"\n",
    "        # Every component returns a (signal, trace) pair that feeds the next one.\n",
    "        signal, trace = self.input(input)\n",
    "\n",
    "        # Hidden layer\n",
    "        signal, trace = self.mlp1_c(signal, trace)\n",
    "        signal, trace = self.neuron1(signal, trace)\n",
    "\n",
    "        # Output layer; its trace is not returned\n",
    "        signal, trace = self.mlp2_c(signal, trace)\n",
    "        out_signal, _ = self.neuron2(signal, trace)\n",
    "\n",
    "        return out_signal"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Dataset\n",
    " Simple Boolean AND dataset, generated to match the input dimensions of the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transform applied to the samples\n",
    "data_transform = transforms.Compose(\n",
    "    [\n",
    "        # BooleanNoise(0.2, 0.8),\n",
    "        Intensity(intensity)\n",
    "    ]\n",
    ")\n",
    "# Labels are multiplied by the same intensity value\n",
    "lbl_transform = transforms.Lambda(lambda x: x * intensity)\n",
    "\n",
    "train_dataset = AND(\n",
    "    data_encoder=PoissonEncoder(duration, dt),\n",
    "    data_transform=data_transform,\n",
    "    lbl_transform=lbl_transform,\n",
    "    # Floor division keeps the repeat count an int (n_in / 2 yields the float 3.0).\n",
    "    # Presumably the two AND inputs are repeated to fill the n_in input neurons -- verify against AND.\n",
    "    repeats=n_in // 2,\n",
    ")\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "Perform training and write the network graph to TensorBoard (uses `SummaryWriter` from `torch.utils.tensorboard`, which requires `tensorboard` to be installed)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\")  # NOTE: not used further down in this cell\n",
    "net = Network()\n",
    "\n",
    "# Add graph to tensorboard, using the first time step of the first sample as example input\n",
    "logger = SummaryWriter()\n",
    "first_sample = next(iter(train_dataloader))[0]\n",
    "logger.add_graph(net, first_sample[:, :, :, 0])\n",
    "logger.close()  # flush the event file; the writer is not used after this point\n",
    "\n",
    "# add_graph runs the network on the example input; clear the state it leaves\n",
    "# behind so that training starts from a fresh network\n",
    "net.reset_state()\n",
    "\n",
    "# Learning rule definition\n",
    "layers = net.layer_state_dict()\n",
    "learning_rule = FedeSTDP(layers, lr, w_init, a)\n",
    "\n",
    "# Training loop\n",
    "out = []\n",
    "for sample, label in tqdm(train_dataloader):\n",
    "    single_out = []\n",
    "\n",
    "    # Iterate over the sample's time dimension (last axis)\n",
    "    for idx in range(sample.shape[-1]):\n",
    "        # Named step_input so the builtin input() is not shadowed\n",
    "        step_input = sample[:, :, :, idx]\n",
    "        single_out.append(net(step_input))\n",
    "\n",
    "        # Apply the learning rule after every time step\n",
    "        learning_rule.step()\n",
    "\n",
    "    # Reset the network state between samples\n",
    "    net.reset_state()\n",
    "    # Stack the per-step outputs along a new last axis\n",
    "    out.append(torch.stack(single_out, dim=-1))\n",
    "\n",
    "print(out[0].shape)"
   ]
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
